{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Normalization from scratch\n",
    "\n",
    "[Batch Normalization](https://arxiv.org/abs/1502.03167) is another way to avoid overfitting, gradient explosion or elimination."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import mxnet as mx\n",
    "import numpy as np\n",
    "from mxnet import nd, autograd\n",
    "mx.random.seed(1)\n",
    "ctx = mx.gpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The MNIST dataset\n",
    "\n",
    "Let's prepare the data (as always!)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "num_inputs = 784\n",
    "num_outputs = 10\n",
    "def transform(data, label):\n",
    "    return nd.transpose(data.astype(np.float32), (2,0,1))/255, label.astype(np.float32)\n",
    "train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=transform),\n",
    "                                      batch_size, shuffle=True)\n",
    "test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform),\n",
    "                                     batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch Normalization layer\n",
    "\n",
    "The layer, unlike Dropout, is usually used **before** the activation layer \n",
    "(according to the authors' original paper), instead of after activation layer.\n",
    "\n",
    "The basic idea is doing the normalization then applying a linear scale and shift to the mini-batch:\n",
    "\n",
    "For input mini-batch $B = \\{x_{1, ..., m}\\}$, we want to learn the parameter $\\gamma$ and $\\beta$.\n",
    "The output of the layer is $\\{y_i = BN_{\\gamma, \\beta}(x_i)\\}$, where:\n",
    "\n",
    "$$\\mu_B \\leftarrow \\frac{1}{m}\\sum_{i = 1}^{m}x_i$$\n",
    "$$\\sigma_B^2 \\leftarrow \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\mu_B)^2$$\n",
    "$$\\hat{x_i} \\leftarrow \\frac{x_i - \\mu_B}{\\sqrt{\\sigma_B^2 + \\epsilon}}$$\n",
    "$$y_i \\leftarrow \\gamma \\hat{x_i} + \\beta \\equiv \\mbox{BN}_{\\gamma,\\beta}(x_i)$$\n",
    "\n",
    "* formulas taken from Ioffe, Sergey, and Christian Szegedy. \"Batch normalization: Accelerating deep network training by reducing internal covariate shift.\" International Conference on Machine Learning. 2015.\n",
    "\n",
    "For the spirit of \"from scratch\", we implement the layer by ourselves,\n",
    "by referencing the formulas from the original paper.\n",
    "\n",
    "Pay attention that, when it comes to (2D) CNN, we normalize `batch_size * height * width` over each channel.\n",
    "So that `gamma` and `beta` have the lengths the same as `channel_count`.\n",
    "In our implementation, we need to manually reshape `gamma` and `beta` \n",
    "so that they could (be automatically broadcast and) multipy the matrices in the desired way."
   ]
  },
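  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before writing the layer, here is a tiny sketch of that broadcasting trick (with made-up values, purely for illustration): a length-`C` vector reshaped to `(1, C, 1, 1)` broadcasts against an `(N, C, H, W)` tensor, so the same per-channel factor is applied to every pixel of that channel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v = nd.array([10, 100], ctx=ctx)    # one value per channel (C = 2)\n",
    "T = nd.ones((2, 2, 3, 3), ctx=ctx)  # a dummy (N, C, H, W) tensor\n",
    "# channel 0 is scaled by 10 everywhere, channel 1 by 100\n",
    "v.reshape((1, 2, 1, 1)) * T"
   ]
  },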
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def pure_batch_norm(X, gamma, beta, eps = 1e-5):\n",
    "    if len(X.shape) not in (2, 4):\n",
    "        raise ValueError('only support dense or 2dconv')\n",
    "\n",
    "    # dense\n",
    "    if len(X.shape) == 2:\n",
    "        # mini-batch mean\n",
    "        mean = nd.mean(X, axis=0)\n",
    "        # mini-batch variance\n",
    "        variance = nd.mean((X - mean) ** 2, axis=0)\n",
    "        # normalize\n",
    "        X_hat = (X - mean) * 1.0 / nd.sqrt(variance + eps)\n",
    "        # scale and shift\n",
    "        out = gamma * X_hat + beta\n",
    "    \n",
    "    # 2d conv\n",
    "    elif len(X.shape) == 4:\n",
    "        # extract the dimensions\n",
    "        N, C, H, W = X.shape\n",
    "        # mini-batch mean\n",
    "        mean = nd.mean(X, axis=(0, 2, 3))\n",
    "        # mini-batch variance\n",
    "        variance = nd.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))\n",
    "        # normalize\n",
    "        X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / nd.sqrt(variance.reshape((1, C, 1, 1)) + eps)\n",
    "        # scale and shift\n",
    "        out = gamma.reshape((1, C, 1, 1)) * X_hat + beta.reshape((1, C, 1, 1))\n",
    "    \n",
    "    return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's do some sanity check. We expect each **column** of the input matrix is normalized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[  1.   7.]\n",
       " [  5.   4.]\n",
       " [  6.  10.]]\n",
       "<NDArray 3x2 @gpu(0)>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A = nd.array([1,7,5,4,6,10], ctx=ctx).reshape((3,2))\n",
    "A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[-1.38872862  0.        ]\n",
       " [ 0.46290955 -1.22474384]\n",
       " [ 0.9258191   1.22474384]]\n",
       "<NDArray 3x2 @gpu(0)>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pure_batch_norm(A,\n",
    "    gamma = nd.array([1,1], ctx=ctx), \n",
    "    beta=nd.array([0,0], ctx=ctx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[[[ 1.  6.]\n",
       "   [ 5.  7.]]\n",
       "\n",
       "  [[ 4.  3.]\n",
       "   [ 2.  5.]]]\n",
       "\n",
       "\n",
       " [[[ 6.  3.]\n",
       "   [ 2.  4.]]\n",
       "\n",
       "  [[ 5.  3.]\n",
       "   [ 2.  5.]]]]\n",
       "<NDArray 2x2x2x2 @gpu(0)>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ga = nd.array([1,1], ctx=ctx)\n",
    "be = nd.array([0,0], ctx=ctx)\n",
    "\n",
    "B = nd.array([1,6,5,7,4,3,2,5,6,3,2,4,5,3,2,5,6], ctx=ctx).reshape((2,2,2,2))\n",
    "B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[[[-1.63784397  0.88191599]\n",
       "   [ 0.37796399  1.38586795]]\n",
       "\n",
       "  [[ 0.30779248 -0.51298743]\n",
       "   [-1.33376741  1.12857234]]]\n",
       "\n",
       "\n",
       " [[[ 0.88191599 -0.62993997]\n",
       "   [-1.13389194 -0.12598799]]\n",
       "\n",
       "  [[ 1.12857234 -0.51298743]\n",
       "   [-1.33376741  1.12857234]]]]\n",
       "<NDArray 2x2x2x2 @gpu(0)>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pure_batch_norm(B, ga, be)"
   ]
  },
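  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an extra (optional) check, after normalization each channel of `B` should have roughly zero mean and unit variance when averaged over the `batch_size * height * width` elements of that channel. A minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "B_normed = pure_batch_norm(B, ga, be)\n",
    "print(nd.mean(B_normed, axis=(0, 2, 3)))       # per-channel mean, expect ~0\n",
    "print(nd.mean(B_normed ** 2, axis=(0, 2, 3)))  # per-channel second moment, expect ~1"
   ]
  },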
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It seems that we implement it correctly.\n",
    "\n",
    "Note: Batch Normalization's **backward** pass is a little bit tricky. But we are not covering it here, because the `autograd` from `mxnet` automatically figures it out!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides that, in the testing process, we want to use the mean and variance of the **complete dataset**, instead of those of **mini batches**. In the implementation, we use moving statistics as a trade off, because we don't want to or don't have the ability to compute the statistics of the complete dataset (in the second loop).\n",
    "\n",
    "Then here comes another concern: we need to maintain the moving statistics **along with multiple runs of the BN**. It's an engineering issue rather than a deep/machine learning issue. On the one hand, the moving statistics are similar to `gamma` and `beta`; on the other hand, they are **not** updated by the gradient backwards. In this quick-and-dirty implementation, we use the global dictionary variables to store the statistics, in which each key is the name of the layer (`scope_name`), and the value is the statistics. (**Attention**: always be very careful if you have to use global variables!) Moreover, we have another parameter `is_training` to indicate whether we are doing training or testing.\n",
    "\n",
    "Now we are ready to define our complete `batch_norm()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def batch_norm(X,\n",
    "               gamma,\n",
    "               beta,\n",
    "               momentum = 0.9,\n",
    "               eps = 1e-5,\n",
    "               scope_name = '',\n",
    "               is_training = True,\n",
    "               debug = False):\n",
    "    \"\"\"compute the batch norm \"\"\"\n",
    "    global _BN_MOVING_MEANS, _BN_MOVING_VARS\n",
    "    \n",
    "    #########################\n",
    "    # the usual batch norm transformation\n",
    "    #########################\n",
    "    \n",
    "    if len(X.shape) not in (2, 4):\n",
    "        raise ValueError('the input data shape should be one of:\\n' + \n",
    "                         'dense: (batch size, # of features)\\n' + \n",
    "                         '2d conv: (batch size, # of features, height, width)'\n",
    "                        )\n",
    "    \n",
    "    # dense\n",
    "    if len(X.shape) == 2:\n",
    "        # mini-batch mean\n",
    "        mean = nd.mean(X, axis=0)\n",
    "        # mini-batch variance\n",
    "        variance = nd.mean((X - mean) ** 2, axis=0)\n",
    "        # normalize\n",
    "        if is_training:\n",
    "            # while training, we normalize the data using its mean and variance\n",
    "            X_hat = (X - mean) * 1.0 / nd.sqrt(variance + eps)\n",
    "        else:\n",
    "            # while testing, we normalize the data using the pre-computed mean and variance\n",
    "            X_hat = (X - _BN_MOVING_MEANS[scope_name]) *1.0 / nd.sqrt(_BN_MOVING_VARS[scope_name] + eps)\n",
    "        # scale and shift\n",
    "        out = gamma * X_hat + beta\n",
    "    \n",
    "    # 2d conv\n",
    "    elif len(X.shape) == 4:\n",
    "        # extract the dimensions\n",
    "        N, C, H, W = X.shape\n",
    "        # mini-batch mean\n",
    "        mean = nd.mean(X, axis=(0,2,3))\n",
    "        # mini-batch variance\n",
    "        variance = nd.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))\n",
    "        # normalize\n",
    "        X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / nd.sqrt(variance.reshape((1, C, 1, 1)) + eps)\n",
    "        if is_training:\n",
    "            # while training, we normalize the data using its mean and variance\n",
    "            X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / nd.sqrt(variance.reshape((1, C, 1, 1)) + eps)\n",
    "        else:\n",
    "            # while testing, we normalize the data using the pre-computed mean and variance\n",
    "            X_hat = (X - _BN_MOVING_MEANS[scope_name].reshape((1, C, 1, 1))) * 1.0 \\\n",
    "                / nd.sqrt(_BN_MOVING_VARS[scope_name].reshape((1, C, 1, 1)) + eps)\n",
    "        # scale and shift\n",
    "        out = gamma.reshape((1, C, 1, 1)) * X_hat + beta.reshape((1, C, 1, 1))\n",
    "      \n",
    "    #########################\n",
    "    # to keep the moving statistics\n",
    "    #########################\n",
    "    \n",
    "    # init the attributes\n",
    "    try: # to access them\n",
    "        _BN_MOVING_MEANS, _BN_MOVING_VARS\n",
    "    except: # error, create them\n",
    "        _BN_MOVING_MEANS, _BN_MOVING_VARS = {}, {}\n",
    "    \n",
    "    # store the moving statistics by their scope_names, inplace    \n",
    "    if scope_name not in _BN_MOVING_MEANS:\n",
    "        _BN_MOVING_MEANS[scope_name] = mean\n",
    "    else:\n",
    "        _BN_MOVING_MEANS[scope_name] = _BN_MOVING_MEANS[scope_name] * momentum + mean * (1.0 - momentum)\n",
    "    if scope_name not in _BN_MOVING_VARS:\n",
    "        _BN_MOVING_VARS[scope_name] = variance\n",
    "    else:\n",
    "        _BN_MOVING_VARS[scope_name] = _BN_MOVING_VARS[scope_name] * momentum + variance * (1.0 - momentum)\n",
    "        \n",
    "    #########################\n",
    "    # debug info\n",
    "    #########################\n",
    "    if debug:\n",
    "        print('== info start ==')\n",
    "        print('scope_name = {}'.format(scope_name))\n",
    "        print('mean = {}'.format(mean))\n",
    "        print('var = {}'.format(variance))\n",
    "        print('_BN_MOVING_MEANS = {}'.format(_BN_MOVING_MEANS[scope_name]))\n",
    "        print('_BN_MOVING_VARS = {}'.format(_BN_MOVING_VARS[scope_name]))\n",
    "        print('output = {}'.format(out))\n",
    "        print('== info end ==')\n",
    " \n",
    "    #########################\n",
    "    # return\n",
    "    #########################\n",
    "    return out"
   ]
  },
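  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal usage sketch (with a made-up scope name `'demo'`, reusing the small matrix `A` from above): a call with `is_training=True` normalizes with the batch statistics and records them under that name, while a later call with `is_training=False` reuses the stored statistics instead of the batch's own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training mode: normalize with the batch statistics and update the moving ones\n",
    "batch_norm(A, nd.array([1, 1], ctx=ctx), nd.array([0, 0], ctx=ctx),\n",
    "           scope_name='demo', is_training=True)\n",
    "print(_BN_MOVING_MEANS['demo'], _BN_MOVING_VARS['demo'])\n",
    "# testing mode: reuse the stored moving statistics instead\n",
    "batch_norm(A, nd.array([1, 1], ctx=ctx), nd.array([0, 0], ctx=ctx),\n",
    "           scope_name='demo', is_training=False)"
   ]
  },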
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters and gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#######################\n",
    "#  Set the scale for weight initialization and choose \n",
    "#  the number of hidden units in the fully-connected layer \n",
    "####################### \n",
    "weight_scale = .01\n",
    "num_fc = 128\n",
    "\n",
    "W1 = nd.random_normal(shape=(20, 1, 3,3), scale=weight_scale, ctx=ctx) \n",
    "b1 = nd.random_normal(shape=20, scale=weight_scale, ctx=ctx)\n",
    "\n",
    "gamma1 = nd.random_normal(shape=20, loc=1, scale=weight_scale, ctx=ctx)\n",
    "beta1 = nd.random_normal(shape=20, scale=weight_scale, ctx=ctx)\n",
    "\n",
    "W2 = nd.random_normal(shape=(50, 20, 5, 5), scale=weight_scale, ctx=ctx)\n",
    "b2 = nd.random_normal(shape=50, scale=weight_scale, ctx=ctx)\n",
    "\n",
    "gamma2 = nd.random_normal(shape=50, loc=1, scale=weight_scale, ctx=ctx)\n",
    "beta2 = nd.random_normal(shape=50, scale=weight_scale, ctx=ctx)\n",
    "\n",
    "W3 = nd.random_normal(shape=(800, num_fc), scale=weight_scale, ctx=ctx)\n",
    "b3 = nd.random_normal(shape=num_fc, scale=weight_scale, ctx=ctx)\n",
    "\n",
    "gamma3 = nd.random_normal(shape=num_fc, loc=1, scale=weight_scale, ctx=ctx)\n",
    "beta3 = nd.random_normal(shape=num_fc, scale=weight_scale, ctx=ctx)\n",
    "\n",
    "W4 = nd.random_normal(shape=(num_fc, num_outputs), scale=weight_scale, ctx=ctx)\n",
    "b4 = nd.random_normal(shape=10, scale=weight_scale, ctx=ctx)\n",
    "\n",
    "params = [W1, b1, gamma1, beta1, W2, b2, gamma2, beta2, W3, b3, gamma3, beta3, W4, b4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for param in params:\n",
    "    param.attach_grad()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Activation fucntions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def relu(X):\n",
    "    return nd.maximum(X, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Softmax output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def softmax(y_linear):\n",
    "    exp = nd.exp(y_linear-nd.max(y_linear))\n",
    "    partition = nd.nansum(exp, axis=0, exclude=True).reshape((-1,1))\n",
    "    return exp / partition"
   ]
  },
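  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sketch to convince ourselves that `softmax()` behaves as expected: each row of the output should be a probability distribution, i.e. non-negative and summing to 1. (The input values below are arbitrary.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_scores = nd.array([[1, 2, 3], [0, 0, 0]], ctx=ctx)\n",
    "print(softmax(sample_scores))\n",
    "print(nd.sum(softmax(sample_scores), axis=1))   # each row sums to 1"
   ]
  },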
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The *softmax* cross-entropy loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def softmax_cross_entropy(yhat_linear, y):\n",
    "    return - nd.nansum(y * nd.log_softmax(yhat_linear), axis=0, exclude=True)"
   ]
  },
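  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny sketch of how the loss is used (with made-up numbers): it takes the unnormalized scores and a one-hot label, and returns one loss value per example. The loss is small when the largest score lines up with the label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_scores = nd.array([[2.0, 0.5, -1.0]], ctx=ctx)    # one example, three classes\n",
    "example_label = nd.one_hot(nd.array([0], ctx=ctx), 3)     # the correct class is 0\n",
    "softmax_cross_entropy(example_scores, example_label)"
   ]
  },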
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model\n",
    "\n",
    "We insert the BN layer right after each linear layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def net(X, is_training = True, debug=False):\n",
    "    ########################\n",
    "    #  Define the computation of the first convolutional layer\n",
    "    ########################\n",
    "    h1_conv = nd.Convolution(data=X, weight=W1, bias=b1, kernel=(3,3), num_filter=20)\n",
    "    h1_normed = batch_norm(h1_conv, gamma1, beta1, scope_name='bn1', is_training=is_training)\n",
    "    h1_activation = relu(h1_normed)\n",
    "    h1 = nd.Pooling(data=h1_activation, pool_type=\"avg\", kernel=(2,2), stride=(2,2))\n",
    "    if debug:\n",
    "        print(\"h1 shape: %s\" % (np.array(h1.shape)))\n",
    "        \n",
    "    ########################\n",
    "    #  Define the computation of the second convolutional layer\n",
    "    ########################\n",
    "    h2_conv = nd.Convolution(data=h1, weight=W2, bias=b2, kernel=(5,5), num_filter=50)\n",
    "    h2_normed = batch_norm(h2_conv, gamma2, beta2, scope_name='bn2', is_training=is_training)\n",
    "    h2_activation = relu(h2_normed)\n",
    "    h2 = nd.Pooling(data=h2_activation, pool_type=\"avg\", kernel=(2,2), stride=(2,2))\n",
    "    if debug:\n",
    "        print(\"h2 shape: %s\" % (np.array(h2.shape)))\n",
    "    \n",
    "    ########################\n",
    "    #  Flattening h2 so that we can feed it into a fully-connected layer\n",
    "    ########################\n",
    "    h2 = nd.flatten(h2)\n",
    "    if debug:\n",
    "        print(\"Flat h2 shape: %s\" % (np.array(h2.shape)))\n",
    "    \n",
    "    ########################\n",
    "    #  Define the computation of the third (fully-connected) layer\n",
    "    ########################\n",
    "    h3_linear = nd.dot(h2, W3) + b3\n",
    "    h3_normed = batch_norm(h3_linear, gamma3, beta3, scope_name='bn3', is_training=is_training)\n",
    "    h3 = relu(h3_normed)\n",
    "    if debug:\n",
    "        print(\"h3 shape: %s\" % (np.array(h3.shape)))\n",
    "        \n",
    "    ########################\n",
    "    #  Define the computation of the output layer\n",
    "    ########################\n",
    "    yhat_linear = nd.dot(h3, W4) + b4\n",
    "    if debug:\n",
    "        print(\"yhat_linear shape: %s\" % (np.array(yhat_linear.shape)))\n",
    "    \n",
    "    return yhat_linear\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test run\n",
    "\n",
    "Can data be passed into the `net()`?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for data, _ in train_data:\n",
    "    data = data.as_in_context(ctx)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "h1 shape: [64 20 13 13]\n",
      "h2 shape: [64 50  4  4]\n",
      "Flat h2 shape: [ 64 800]\n",
      "h3 shape: [ 64 128]\n",
      "yhat_linear shape: [64 10]\n"
     ]
    }
   ],
   "source": [
    "output = net(data, is_training=True, debug=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def SGD(params, lr):    \n",
    "    for param in params:\n",
    "        param[:] = param - lr * param.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iterator, net):\n",
    "    numerator = 0.\n",
    "    denominator = 0.\n",
    "    for i, (data, label) in enumerate(data_iterator):\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        label_one_hot = nd.one_hot(label, 10)\n",
    "        output = net(data, is_training=False) # attention here!\n",
    "        predictions = nd.argmax(output, axis=1)\n",
    "        numerator += nd.sum(predictions == label)\n",
    "        denominator += data.shape[0]\n",
    "    return (numerator / denominator).asscalar()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Execute the training loop\n",
    "\n",
    "Note: you may want to use a gpu to run the code below. (And remember to set the `ctx = mx.gpu()` accordingly in the very beginning of this article.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Loss: 0.060423938439, Train_acc 0.988017, Test_acc 0.9883\n",
      "Epoch 1. Loss: 0.0381374779348, Train_acc 0.993167, Test_acc 0.9913\n",
      "Epoch 2. Loss: 0.03584648306, Train_acc 0.994733, Test_acc 0.9926\n",
      "Epoch 3. Loss: 0.0266538889392, Train_acc 0.995667, Test_acc 0.9927\n",
      "Epoch 4. Loss: 0.0211588057153, Train_acc 0.996267, Test_acc 0.9929\n",
      "Epoch 5. Loss: 0.0195035166958, Train_acc 0.996533, Test_acc 0.9926\n",
      "Epoch 6. Loss: 0.0182358424663, Train_acc 0.997883, Test_acc 0.9935\n",
      "Epoch 7. Loss: 0.0131261121093, Train_acc 0.9978, Test_acc 0.9936\n",
      "Epoch 8. Loss: 0.0127412964013, Train_acc 0.998217, Test_acc 0.9925\n",
      "Epoch 9. Loss: 0.0101101914425, Train_acc 0.998967, Test_acc 0.9937\n"
     ]
    }
   ],
   "source": [
    "epochs = 10\n",
    "moving_loss = 0.\n",
    "learning_rate = .001\n",
    "\n",
    "for e in range(epochs):\n",
    "    for i, (data, label) in enumerate(train_data):\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        label_one_hot = nd.one_hot(label, num_outputs)\n",
    "        with autograd.record():\n",
    "            # we are in training process,\n",
    "            # so we normalize the data using batch mean and variance\n",
    "            output = net(data, is_training=True)\n",
    "            loss = softmax_cross_entropy(output, label_one_hot)\n",
    "        loss.backward()\n",
    "        SGD(params, learning_rate)\n",
    "        \n",
    "        ##########################\n",
    "        #  Keep a moving average of the losses\n",
    "        ##########################\n",
    "        if i == 0:\n",
    "            moving_loss = nd.mean(loss).asscalar()\n",
    "        else:\n",
    "            moving_loss = .99 * moving_loss + .01 * nd.mean(loss).asscalar()\n",
    "            \n",
    "    test_accuracy = evaluate_accuracy(test_data, net)\n",
    "    train_accuracy = evaluate_accuracy(train_data, net)\n",
    "    print(\"Epoch %s. Loss: %s, Train_acc %s, Test_acc %s\" % (e, moving_loss, train_accuracy, test_accuracy)) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Conclusion\n",
    "\n",
    "Compared with the pure mlp, dropout and pure cnn results, we achieve about 99% accuracy on this task **even just after the second epoch**, with just two hidden conv layers and one hidden dense layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next\n",
    "[Batch normalization with gluon](../chapter04_convolutional-neural-networks/cnn-batch-norm-gluon.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For whinges or inquiries, [open an issue on  GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
